{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Concise Implementation of Softmax Regression\n",
    ":label:`sec_softmax_djl`\n",
    "\n",
    "Just as DJL made it much easier\n",
    "to implement linear regression in :numref:`sec_linear_djl`,\n",
    "we will find it similarly (or possibly more)\n",
    "convenient for implementing classification models.\n",
    "Again, we begin with our import ritual."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "\n",
    "import ai.djl.basicdataset.cv.classification.*;\n",
    "import ai.djl.metric.*;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us stick with the Fashion-MNIST dataset \n",
    "and keep the batch size at $256$ as in the last section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "2"
    }
   },
   "outputs": [],
   "source": [
    "int batchSize = 256;\n",
    "boolean randomShuffle = true;\n",
    "\n",
    "// Get Training and Validation Datasets\n",
    "FashionMnist trainingSet = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TRAIN)\n",
    "        .setSampling(batchSize, randomShuffle)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();\n",
    "\n",
    "\n",
    "FashionMnist validationSet = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TEST)\n",
    "        .setSampling(batchSize, false)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initializing Model Parameters\n",
    "\n",
    "As mentioned in :numref:`sec_softmax`,\n",
    "the output layer of softmax regression \n",
    "is a fully-connected (`Dense`) layer.\n",
    "Therefore, to implement our model,\n",
    "we just need to add one `Dense` layer \n",
    "with 10 outputs to our `Sequential`.\n",
    "Again, here, the `Sequential` is not really necessary,\n",
    "but we might as well form the habit since it will be ubiquitous\n",
    "when implementing deep models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public class ActivationFunction {\n",
    "    public static NDList softmax(NDList arrays) {\n",
    "        return new NDList(arrays.singletonOrThrow().logSoftmax(1));\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "3"
    }
   },
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();\n",
    "\n",
    "Model model = Model.newInstance(\"softmax-regression\");\n",
    "\n",
    "SequentialBlock net = new SequentialBlock();\n",
    "net.add(Blocks.batchFlattenBlock(28 * 28)); // flatten input\n",
    "net.add(Linear.builder().setUnits(10).build()); // set 10 output channels\n",
    "\n",
    "model.setBlock(net);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Softmax\n",
    "\n",
    "In the previous example, we calculated our model's output\n",
    "and then ran this output through the cross-entropy loss.\n",
    "Mathematically, that is a perfectly reasonable thing to do.\n",
    "However, from a computational perspective, \n",
    "exponentiation can be a source of numerical stability issues\n",
    "(as discussed  in :numref:`sec_naive_bayes`).\n",
    "Recall that the softmax function calculates\n",
    "$\\hat y_j = \\frac{e^{z_j}}{\\sum_{i=1}^{n} e^{z_i}}$, \n",
    "where $\\hat y_j$ is the $j^\\mathrm{th}$ element of ``yHat`` \n",
    "and $z_j$ is the $j^\\mathrm{th}$ element of the input\n",
    "``yLinear`` variable, as computed by the softmax.\n",
    "\n",
    "If some of the $z_i$ are very large (i.e., very positive),\n",
    "then $e^{z_i}$ might be larger than the largest number\n",
    "we can have for certain types of ``float`` (i.e., overflow).\n",
    "This would make the denominator (and/or numerator) ``inf`` \n",
    "and we wind up encountering either $0$, ``inf``, or ``nan`` for $\\hat y_j$.\n",
    "In these situations we do not get a well-defined \n",
    "return value for ``crossEntropy()``.\n",
    "One trick to get around this is to first subtract $\\text{max}(z_i)$\n",
    "from all $z_i$ before proceeding with the ``softmax`` calculation.\n",
    "You can verify that this shifting of each $z_i$ by constant factor\n",
    "does not change the return value of ``softmax()``.\n",
    "\n",
    "After the subtraction and normalization step,\n",
    "it might be possible that some $z_j$ have large negative values\n",
    "and thus that the corresponding $e^{z_j}$ will take values close to zero.\n",
    "These might be rounded to zero due to finite precision (i.e underflow),\n",
    "making $\\hat y_j$ zero and giving us ``-inf`` for $\\text{log}(\\hat y_j)$.\n",
    "A few steps down the road in backpropagation,\n",
    "we might find ourselves faced with a screenful \n",
    "of the dreaded not-a-number (``nan``) results.\n",
    "\n",
    "Fortunately, we are saved by the fact that \n",
    "even though we are computing exponential functions, \n",
    "we ultimately intend to take their log \n",
    "(when calculating the cross-entropy loss).\n",
    "By combining these two operators \n",
    "(``softmax`` and ``crossEntropy``) together,\n",
    "we can escape the numerical stability issues\n",
    "that might otherwise plague us during backpropagation.\n",
    "As shown in the equation below, we avoided calculating $e^{z_j}$\n",
    "and can use instead $z_j$ directly due to the canceling in $\\log(\\exp(\\cdot))$.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\log{(\\hat y_j)} & = \\log\\left( \\frac{e^{z_j}}{\\sum_{i=1}^{n} e^{z_i}}\\right) \\\\\n",
    "& = \\log{(e^{z_j})}-\\text{log}{\\left( \\sum_{i=1}^{n} e^{z_i} \\right)} \\\\\n",
    "& = z_j -\\log{\\left( \\sum_{i=1}^{n} e^{z_i} \\right)}.\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "We will want to keep the conventional softmax function handy\n",
    "in case we ever want to evaluate the probabilities output by our model.\n",
    "But instead of passing softmax probabilities into our new loss function,\n",
    "we will just pass the logits and compute the softmax and its log\n",
    "all at once inside the softmaxCrossEntropy loss function,\n",
    "which does smart things like the log-sum-exp trick ([see on Wikipedia](https://en.wikipedia.org/wiki/LogSumExp))."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "4"
    }
   },
   "outputs": [],
   "source": [
    "Loss loss = Loss.softmaxCrossEntropyLoss();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimization Algorithm\n",
    "\n",
    "Here, we use minibatch stochastic gradient descent\n",
    "with a learning rate of $0.1$ as the optimization algorithm.\n",
    "Note that this is the same as we applied in the linear regression example\n",
    "and it illustrates the general applicability of the optimizers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "5"
    }
   },
   "outputs": [],
   "source": [
    "Tracker lrt = Tracker.fixed(0.1f);\n",
    "Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Instantiate Configuration\n",
    "Now we'll create a training configuration that\n",
    "describes how we want to train our model.\n",
    "We will then create a `trainer` to do the\n",
    "training for us."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
    "    .optOptimizer(sgd) // Optimizer\n",
    "    .optDevices(manager.getEngine().getDevices(1)) // single GPU\n",
    "    .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "    .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "Trainer trainer = model.newTrainer(config);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initializing Trainer\n",
    "We initialize the trainer with input shape ($1$, $748$)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.initialize(new Shape(1, 28 * 28)); // Input Images are 28 x 28"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Metrics\n",
    "Now we tell DJL to record metrics! (Remember, DJL doesn't record metrics unless you tell it to!)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Metrics metrics = new Metrics();\n",
    "trainer.setMetrics(metrics);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "In :numref:`sec_linear_djl`, we train the model by explicitly calling `EasyTrain` to train each batch and then updating the parameters. We can actually instead call `EasyTrain`'s `fit()` function to do this for us in 1 line. It takes in a given number of epochs, a training set, and a validation set and, along with the training, will do the validation for us as well!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "int numEpochs = 3;\n",
    "\n",
    "EasyTrain.fit(trainer, numEpochs, trainingSet, validationSet);\n",
    "var result = trainer.getTrainingResult();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As before, this algorithm converges to a solution\n",
    "that achieves an accuracy of 83.7%,\n",
    "albeit this time with fewer lines of code than before.\n",
    "Note that in many cases, DJL takes additional precautions\n",
    "beyond these most well-known tricks to ensure numerical stability,\n",
    "saving us from even more pitfalls that we would encounter\n",
    "if we tried to code all of our models from scratch in practice.\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Try adjusting the hyper-parameters, such as batch size, epoch, and learning rate, to see what the results are.\n",
    "1. Why might the test accuracy decrease again after a while? How could we fix this?"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
